{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xmf_JRJa_N8C"
      },
      "source": [
        "<table align=\"center\">\n",
        "  <td align=\"center\"><a target=\"_blank\" href=\"http://introtodeeplearning.com\">\n",
        "        <img src=\"https://i.ibb.co/Jr88sn2/mit.png\" style=\"padding-bottom:5px;\" />\n",
        "      Visit MIT Deep Learning</a></td>\n",
        "  <td align=\"center\"><a target=\"_blank\" href=\"https://colab.research.google.com/github/MITDeepLearning/introtodeeplearning/blob/master/lab2/solutions/PT_Part1_MNIST_Solution.ipynb\">\n",
        "        <img src=\"https://i.ibb.co/2P3SLwK/colab.png\"  style=\"padding-bottom:5px;\" />Run in Google Colab</a></td>\n",
        "  <td align=\"center\"><a target=\"_blank\" href=\"https://github.com/MITDeepLearning/introtodeeplearning/blob/master/lab2/solutions/PT_Part1_MNIST_Solution.ipynb\">\n",
        "        <img src=\"https://i.ibb.co/xfJbPmL/github.png\"  height=\"70px\" style=\"padding-bottom:5px;\"  />View Source on GitHub</a></td>\n",
        "</table>\n",
        "\n",
        "# Copyright Information"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gKA_J7bdP33T"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 MIT Introduction to Deep Learning. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the MIT License. You may not use this file except in compliance\n",
        "# with the License. Use and/or modification of this code outside of MIT Introduction\n",
        "# to Deep Learning must reference:\n",
        "#\n",
        "# © MIT Introduction to Deep Learning\n",
        "# http://introtodeeplearning.com\n",
        "#"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Cm1XpLftPi4A"
      },
      "source": [
        "# Laboratory 2: Computer Vision\n",
        "\n",
        "# Part 1: MNIST Digit Classification\n",
        "\n",
        "In the first portion of this lab, we will build and train a convolutional neural network (CNN) for classification of handwritten digits from the famous [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. The MNIST dataset consists of 60,000 training images and 10,000 test images. Our classes are the digits 0-9.\n",
        "\n",
        "First, let's download the course repository, install dependencies, and import the relevant packages we'll need for this lab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RsGqx_ai_N8F"
      },
      "outputs": [],
      "source": [
        "# Import PyTorch and other relevant libraries\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torchvision\n",
        "import torchvision.datasets as datasets\n",
        "import torchvision.transforms as transforms\n",
        "from torch.utils.data import DataLoader\n",
        "from torchsummary import summary\n",
        "\n",
        "# MIT introduction to deep learning package\n",
        "!pip install mitdeeplearning --quiet\n",
        "import mitdeeplearning as mdl\n",
        "\n",
        "# other packages\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import random\n",
        "from tqdm import tqdm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nCpHDxX1bzyZ"
      },
      "source": [
        "We'll also install Comet. If you followed the instructions from Lab 1, you should have your Comet account set up. Enter your API key below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSR_PAqjbzyZ"
      },
      "outputs": [],
      "source": [
        "!pip install comet_ml > /dev/null 2>&1\n",
        "import comet_ml\n",
        "# TODO: ENTER YOUR API KEY HERE!!\n",
        "COMET_API_KEY = \"\"\n",
        "\n",
        "# Check that we are using a GPU, if not switch runtimes\n",
        "#   using Runtime > Change Runtime Type > GPU\n",
        "assert torch.cuda.is_available(), \"Please enable GPU from runtime settings\"\n",
        "assert COMET_API_KEY != \"\", \"Please insert your Comet API Key\"\n",
        "\n",
        "# Set GPU for computation\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wGPDtVxvTtPk"
      },
      "outputs": [],
      "source": [
        "# start a first comet experiment for the first part of the lab\n",
        "comet_ml.init(project_name=\"6S191_lab2_part1_NN\")\n",
        "comet_model_1 = comet_ml.Experiment()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HKjrdUtX_N8J"
      },
      "source": [
        "## 1.1 MNIST dataset\n",
        "\n",
        "Let's download and load the dataset and display a few random samples from it:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G1Bryi5ssUNX"
      },
      "outputs": [],
      "source": [
        "# Download and transform the MNIST dataset\n",
        "transform = transforms.Compose([\n",
        "    # Convert images to PyTorch tensors which also scales data from [0,255] to [0,1]\n",
        "    transforms.ToTensor()\n",
        "])\n",
        "\n",
        "# Download training and test datasets\n",
        "train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)\n",
        "test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "D_AhlQB4sUNX"
      },
      "source": [
        "The MNIST dataset object in PyTorch is not a simple tensor or array. It's an iterable dataset that loads samples (image-label pairs) one at a time or in batches. In a later section of this lab, we will define a handy DataLoader to process the data in batches."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LpxeLuaysUNX"
      },
      "outputs": [],
      "source": [
        "image, label = train_dataset[0]\n",
        "print(image.size())  # For a tensor: torch.Size([1, 28, 28])\n",
        "print(label)  # For a label: integer (e.g., 5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5ZtUqOqePsRD"
      },
      "source": [
        "Our training set is made up of 28x28 grayscale images of handwritten digits.\n",
        "\n",
        "Let's visualize what some of these images and their corresponding training labels look like."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bDBsR2lP_N8O",
        "scrolled": true
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(10,10))\n",
        "random_inds = np.random.choice(60000,36)\n",
        "for i in range(36):\n",
        "    plt.subplot(6, 6, i + 1)\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    image_ind = random_inds[i]\n",
        "    image, label = train_dataset[image_ind]\n",
        "    plt.imshow(image.squeeze(), cmap=plt.cm.binary)\n",
        "    plt.xlabel(label)\n",
        "comet_model_1.log_figure(figure=plt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "V6hd3Nt1_N8q"
      },
      "source": [
        "## 1.2 Neural Network for Handwritten Digit Classification\n",
        "\n",
        "We'll first build a simple neural network consisting of two fully connected layers and apply this to the digit classification task. Our network will ultimately output a probability distribution over the 10 digit classes (0-9). This first architecture we will be building is depicted below:\n",
        "\n",
        "![alt_text](https://raw.githubusercontent.com/MITDeepLearning/introtodeeplearning/master/lab2/img/mnist_2layers_arch.png \"CNN Architecture for MNIST Classification\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rphS2rMIymyZ"
      },
      "source": [
        "### Fully connected neural network architecture\n",
        "To define the architecture of this first fully connected neural network, we'll once again use the the `torch.nn` modules, defining the model using [`nn.Sequential`](https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html). Note how we first use a [`nn.Flatten`](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Flatten) layer, which flattens the input so that it can be fed into the model.\n",
        "\n",
        "In this next block, you'll define the fully connected layers of this simple network."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MMZsbjAkDKpU"
      },
      "outputs": [],
      "source": [
        "def build_fc_model():\n",
        "    fc_model = nn.Sequential(\n",
        "        # First define a Flatten layer\n",
        "        nn.Flatten(),\n",
        "\n",
        "        nn.Linear(28 * 28, 128),\n",
        "        # '''TODO: Define the activation function.'''\n",
        "        nn.ReLU(),\n",
        "\n",
        "        # '''TODO: Define the second Linear layer'''\n",
        "        nn.Linear(128, 10),\n",
        "        )\n",
        "    return fc_model\n",
        "\n",
        "fc_model_sequential = build_fc_model()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VtGZpHVKz5Jt"
      },
      "source": [
        "As we progress through this next portion, you may find that you'll want to make changes to the architecture defined above. **Note that in order to update the model later on, you'll need to re-run the above cell to re-initialize the model.**"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mVN1_AeG_N9N"
      },
      "source": [
        "Let's take a step back and think about the network we've just created. The first layer in this network, `nn.Flatten`, transforms the format of the images from a 2d-array (28 x 28 pixels), to a 1d-array of 28 * 28 = 784 pixels. You can think of this layer as unstacking rows of pixels in the image and lining them up. There are no learned parameters in this layer; it only reformats the data.\n",
        "\n",
        "After the pixels are flattened, the network consists of a sequence of two `nn.Linear` layers. These are fully-connected neural layers. The first `nn.Linear` layer has 128 nodes (or neurons). The second (and last) layer (which you've defined!) should return an array of probability scores that sum to 1. Each node contains a score that indicates the probability that the current image belongs to one of the handwritten digit classes.\n",
        "\n",
        "That defines our fully connected model!"
      ]
    },
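    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the layer sizes concrete, here is a minimal sketch that counts the learnable parameters of this network by hand. It assumes the layer sizes used above (784 inputs, 128 hidden nodes, 10 outputs) and that each `nn.Linear` layer holds one weight per input-output pair plus one bias per output. The helper name `linear_param_count` is hypothetical and introduced only for illustration."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper (not part of the lab code): parameter count of a\n",
        "# fully connected layer that has a bias term\n",
        "def linear_param_count(in_features, out_features):\n",
        "    # one weight per (input, output) pair, plus one bias per output\n",
        "    return in_features * out_features + out_features\n",
        "\n",
        "flatten_params = 0  # nn.Flatten only reshapes the data; it learns nothing\n",
        "fc1_params = linear_param_count(28 * 28, 128)\n",
        "fc2_params = linear_param_count(128, 10)\n",
        "total_params = flatten_params + fc1_params + fc2_params\n",
        "\n",
        "print('first Linear layer:', fc1_params)\n",
        "print('second Linear layer:', fc2_params)\n",
        "print('total:', total_params)\n",
        "\n",
        "# If fc_model_sequential from above is in scope, this should report the same total:\n",
        "# sum(p.numel() for p in fc_model_sequential.parameters())"
      ]
    },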
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kquVpHqPsUNX"
      },
      "source": [
        "### Embracing subclassing in PyTorch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RyqD3eJgsUNX"
      },
      "source": [
        "Recall that in Lab 1, we explored creating more flexible models by subclassing [`nn.Module`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html). This technique of defining models is more commonly used in PyTorch. We will practice using this approach of subclassing to define our models for the rest of the lab."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7JhFJXjYsUNX"
      },
      "outputs": [],
      "source": [
        "# Define the fully connected model\n",
        "class FullyConnectedModel(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(FullyConnectedModel, self).__init__()\n",
        "        self.flatten = nn.Flatten()\n",
        "        self.fc1 = nn.Linear(28 * 28, 128)\n",
        "\n",
        "        # '''TODO: Define the activation function for the first fully connected layer'''\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "        # '''TODO: Define the second Linear layer to output the classification probabilities'''\n",
        "        self.fc2 = nn.Linear(128, 10)\n",
        "        # self.fc2 = # TODO\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.flatten(x)\n",
        "        x = self.fc1(x)\n",
        "\n",
        "        # '''TODO: Implement the rest of forward pass of the model using the layers you have defined above'''\n",
        "        x = self.relu(x)\n",
        "        x = self.fc2(x)\n",
        "        # '''TODO'''\n",
        "\n",
        "        return x\n",
        "\n",
        "fc_model = FullyConnectedModel().to(device) # send the model to GPU"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gut8A_7rCaW6"
      },
      "source": [
        "### Model Metrics and Training Parameters\n",
        "\n",
        "Before training the model, we need to define components that govern its performance and guide its learning process. These include the loss function, optimizer, and evaluation metrics:\n",
        "\n",
        "* *Loss function* — This defines how we measure how accurate the model is during training. As was covered in lecture, during training we want to minimize this function, which will \"steer\" the model in the right direction.\n",
        "* *Optimizer* — This defines how the model is updated based on the data it sees and its loss function.\n",
        "* *Metrics* — Here we can define metrics that we want to use to monitor the training and testing steps. In this example, we'll define and take a look at the *accuracy*, the fraction of the images that are correctly classified.\n",
        "\n",
        "We'll start out by using a stochastic gradient descent (SGD) optimizer initialized with a learning rate of 0.1. Since we are performing a categorical classification task, we'll want to use the [cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).\n",
        "\n",
        "You'll want to experiment with both the choice of optimizer and learning rate and evaluate how these affect the accuracy of the trained model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lhan11blCaW7"
      },
      "outputs": [],
      "source": [
        "'''TODO: Experiment with different optimizers and learning rates. How do these affect\n",
        "    the accuracy of the trained model? Which optimizers and/or learning rates yield\n",
        "    the best performance?'''\n",
        "# Define loss function and optimizer\n",
        "loss_function = nn.CrossEntropyLoss()\n",
        "optimizer = optim.SGD(fc_model.parameters(), lr=0.1)"
      ]
    },
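    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of what the cross entropy loss measures, the sketch below works through a single made-up example in plain Python. It converts raw scores (logits) into probabilities with a softmax and then takes the negative log-probability of the true class. The logit values and the helper names `softmax` and `cross_entropy` are assumptions chosen for illustration and are not part of the lab code. `nn.CrossEntropyLoss` performs both steps internally on the model's logits and, by default, averages the result over the batch."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "# Illustrative helpers (hypothetical names), written for a single sample\n",
        "def softmax(logits):\n",
        "    # subtracting the max improves numerical stability without changing the result\n",
        "    m = max(logits)\n",
        "    exps = [math.exp(z - m) for z in logits]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "def cross_entropy(logits, label):\n",
        "    # negative log-probability that the softmax assigns to the true class\n",
        "    return -math.log(softmax(logits)[label])\n",
        "\n",
        "# Made-up logits for one image over the 10 digit classes; class 3 scores highest\n",
        "example_logits = [0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]\n",
        "probs = softmax(example_logits)\n",
        "\n",
        "print('sum of probabilities:', sum(probs))\n",
        "print('predicted class:', probs.index(max(probs)))\n",
        "# The loss is small when the true class gets high probability, larger otherwise\n",
        "print('loss if the true label is 3:', cross_entropy(example_logits, 3))\n",
        "print('loss if the true label is 7:', cross_entropy(example_logits, 7))"
      ]
    },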
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qKF6uW-BCaW-"
      },
      "source": [
        "### Train the model\n",
        "\n",
        "We're now ready to train our model, which will involve feeding the training data (`train_dataset`) into the model, and then asking it to learn the associations between images and labels. We'll also need to define the batch size and the number of epochs, or iterations over the MNIST dataset, to use during training. This dataset consists of a (image, label) tuples that we will iteratively access in batches.\n",
        "\n",
        "In Lab 1, we saw how we can use the [`.backward()`](https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html) method to optimize losses and train models with stochastic gradient descent. In this section, we will define a function to train the model using `.backward()` and `optimizer.step()` to automatically update our model parameters (weights and biases) as we saw in Lab 1.\n",
        "\n",
        "Recall, we mentioned in Section 1.1 that the MNIST dataset can be accessed iteratively in batches. Here, we will define a PyTorch [`DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) that will enable us to do that."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EFMbIqIvQ2X0"
      },
      "outputs": [],
      "source": [
        "# Create DataLoaders for batch processing\n",
        "BATCH_SIZE = 64\n",
        "trainset_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)\n",
        "testset_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)"
      ]
    },
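    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To get a feel for what the training `DataLoader` yields over one epoch, the following back-of-the-envelope sketch uses only the numbers stated above (60,000 training images, batch size 64). It assumes the `DataLoader` default of keeping the final, smaller batch (`drop_last=False`). The variable names are illustrative and not part of the lab code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "\n",
        "num_train_images = 60000   # size of the MNIST training set\n",
        "example_batch_size = 64    # same value as BATCH_SIZE above\n",
        "\n",
        "# Number of batches per epoch when the last partial batch is kept\n",
        "batches_per_epoch = math.ceil(num_train_images / example_batch_size)\n",
        "# The final batch holds whatever is left over after the full batches\n",
        "last_batch_size = num_train_images - (batches_per_epoch - 1) * example_batch_size\n",
        "\n",
        "print('batches per epoch:', batches_per_epoch)\n",
        "print('size of the last batch:', last_batch_size)\n",
        "\n",
        "# Because the last batch is smaller, a per-epoch average loss is more accurate\n",
        "# when each batch loss is weighted by the number of samples in that batch."
      ]
    },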
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dfnnoDwEsUNY"
      },
      "outputs": [],
      "source": [
        "def train(model, dataloader, criterion, optimizer, epochs):\n",
        "    model.train()  # Set the model to training mode\n",
        "    for epoch in range(epochs):\n",
        "        total_loss = 0\n",
        "        correct_pred = 0\n",
        "        total_pred = 0\n",
        "\n",
        "        for images, labels in trainset_loader:\n",
        "            # Move tensors to GPU so compatible with model\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "            # Clear gradients before performing backward pass\n",
        "            optimizer.zero_grad()\n",
        "            # Forward pass\n",
        "            outputs = fc_model(images)\n",
        "            # Calculate loss based on model predictions\n",
        "            loss = loss_function(outputs, labels)\n",
        "            # Backpropagate and update model parameters\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            # multiply loss by total nos. of samples in batch\n",
        "            total_loss += loss.item()*images.size(0)\n",
        "\n",
        "            # Calculate accuracy\n",
        "            predicted = torch.argmax(outputs, dim=1)  # Get predicted class\n",
        "            correct_pred += (predicted == labels).sum().item()  # Count correct predictions\n",
        "            total_pred += labels.size(0) # Count total predictions\n",
        "\n",
        "        # Compute metrics\n",
        "        total_epoch_loss = total_loss / total_pred\n",
        "        epoch_accuracy = correct_pred / total_pred\n",
        "        print(f\"Epoch {epoch + 1}, Loss: {total_epoch_loss}, Accuracy: {epoch_accuracy:.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kIpdv-H0sUNY"
      },
      "outputs": [],
      "source": [
        "# TODO: Train the model by calling the function appropriately\n",
        "EPOCHS = 5\n",
        "train(fc_model, trainset_loader, loss_function, optimizer, EPOCHS)\n",
        "# train('''TODO''') # TODO\n",
        "\n",
        "comet_model_1.end()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W3ZVOhugCaXA"
      },
      "source": [
        "As the model trains, the loss and accuracy metrics are displayed. With five epochs and a learning rate of 0.01, this fully connected model should achieve an accuracy of approximatley 0.97 (or 97%) on the training data."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oEw4bZgGCaXB"
      },
      "source": [
        "### Evaluate accuracy on the test dataset\n",
        "\n",
        "Now that we've trained the model, we can ask it to make predictions about a test set that it hasn't seen before. In this example, iterating over the `testset_loader` allows us to access our test images and test labels. And to evaluate accuracy, we can check to see if the model's predictions match the labels from this loader.\n",
        "\n",
        "Since we have now trained the mode, we will use the eval state of the model on the test dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VflXLEeECaXC"
      },
      "outputs": [],
      "source": [
        "'''TODO: Use the model we have defined in its eval state to complete\n",
        "and call the evaluate function, and calculate the accuracy of the model'''\n",
        "\n",
        "def evaluate(model, dataloader, loss_function):\n",
        "    # Evaluate model performance on the test dataset\n",
        "    model.eval()\n",
        "    test_loss = 0\n",
        "    correct_pred = 0\n",
        "    total_pred = 0\n",
        "    # Disable gradient calculations when in inference mode\n",
        "    with torch.no_grad():\n",
        "        for images, labels in testset_loader:\n",
        "            # TODO: ensure evalaution happens on the GPU\n",
        "            images, labels = images.to(device), labels.to(device)\n",
        "            # images, labels = # TODO\n",
        "\n",
        "            # TODO: feed the images into the model and obtain the predictions (forward pass)\n",
        "            outputs = model(images)\n",
        "            # outputs = # TODO\n",
        "\n",
        "            loss = loss_function(outputs, labels)\n",
        "\n",
        "            # TODO: Calculate test loss\n",
        "            test_loss += loss.item() * images.size(0)\n",
        "            # test_loss += # TODO\n",
        "\n",
        "           '''TODO: make a prediction and determine whether it is correct!'''\n",
        "            # TODO: identify the digit with the highest probability prediction for the images in the test dataset.\n",
        "            predicted = torch.argmax(outputs, dim=1)\n",
        "            # predicted = # TODO\n",
        "\n",
        "            # TODO: tally the number of correct predictions\n",
        "            correct_pred += (predicted == labels).sum().item()\n",
        "            # correct_pred += TODO\n",
        "            # TODO: tally the total number of predictions\n",
        "            total_pred += labels.size(0)\n",
        "            # total_pred += TODO\n",
        "\n",
        "    # Compute average loss and accuracy\n",
        "    test_loss /= total_pred\n",
        "    test_acc = correct_pred / total_pred\n",
        "    return test_loss, test_acc\n",
        "\n",
        "# TODO: call the evaluate function to evaluate the trained model!!\n",
        "test_loss, test_acc = evaluate(fc_model, trainset_loader, loss_function)\n",
        "# test_loss, test_acc = # TODO\n",
        "\n",
        "print('Test accuracy:', test_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yWfgsmVXCaXG"
      },
      "source": [
        "You may observe that the accuracy on the test dataset is a little lower than the accuracy on the training dataset. This gap between training accuracy and test accuracy is an example of *overfitting*, when a machine learning model performs worse on new data than on its training data.\n",
        "\n",
        "What is the highest accuracy you can achieve with this first fully connected model? Since the handwritten digit classification task is pretty straightforward, you may be wondering how we can do better...\n",
        "\n",
        "![Deeper...](https://i.kym-cdn.com/photos/images/newsfeed/000/534/153/f87.jpg)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "baIw9bDf8v6Z"
      },
      "source": [
        "## 1.3 Convolutional Neural Network (CNN) for handwritten digit classification"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_J72Yt1o_fY7"
      },
      "source": [
        "As we saw in lecture, convolutional neural networks (CNNs) are particularly well-suited for a variety of tasks in computer vision, and have achieved near-perfect accuracies on the MNIST dataset. We will now build a CNN composed of two convolutional layers and pooling layers, followed by two fully connected layers, and ultimately output a probability distribution over the 10 digit classes (0-9). The CNN we will be building is depicted below:\n",
        "\n",
        "![alt_text](https://raw.githubusercontent.com/MITDeepLearning/introtodeeplearning/master/lab2/img/convnet_fig.png \"CNN Architecture for MNIST Classification\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EEHqzbJJAEoR"
      },
      "source": [
        "### Define the CNN model\n",
        "\n",
        "We'll use the same training and test datasets as before, and proceed similarly as our fully connected network to define and train our new CNN model. To do this we will explore two layers we have not encountered before: you can use  [`nn.Conv2d`](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html) to define convolutional layers and [`nn.MaxPool2D`](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html) to define the pooling layers. Use the parameters shown in the network architecture above to define these layers and build the CNN model. You can decide to use `nn.Sequential` or to subclass `nn.Module`based on your preference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vec9qcJs-9W5"
      },
      "outputs": [],
      "source": [
        "### Basic CNN in PyTorch ###\n",
        "\n",
        "class CNN(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(CNN, self).__init__()\n",
        "        # TODO: Define the first convolutional layer\n",
        "        self.conv1 = nn.Conv2d(1, 24, kernel_size=3)\n",
        "        # self.conv1 = # TODO\n",
        "\n",
        "        # TODO: Define the first max pooling layer\n",
        "        self.pool1 = nn.MaxPool2d(kernel_size=2)\n",
        "        # self.pool1 = # TODO\n",
        "\n",
        "        # TODO: Define the second convolutional layer\n",
        "        self.conv2 = nn.Conv2d(24, 36, kernel_size=3)\n",
        "        # self.conv2 = # TODO\n",
        "\n",
        "        # TODO: Define the second max pooling layer\n",
        "        self.pool2 = nn.MaxPool2d(kernel_size=2)\n",
        "        # self.pool2 = # TODO\n",
        "\n",
        "        self.flatten = nn.Flatten()\n",
        "        self.fc1 = nn.Linear(36 * 5 * 5, 128)\n",
        "        self.relu = nn.ReLU()\n",
        "\n",
        "        # TODO: Define the Linear layer that outputs the classification\n",
        "        # logits over class labels. Remember that CrossEntropyLoss operates over logits.\n",
        "        self.fc2 = nn.Linear(128, 10)\n",
        "        # self.fc2 = # TODO\n",
        "\n",
        "\n",
        "    def forward(self, x):\n",
        "        # First convolutional and pooling layers\n",
        "        x = self.conv1(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.pool1(x)\n",
        "\n",
        "        # '''TODO: Implement the rest of forward pass of the model using the layers you have defined above'''\n",
        "        #     '''hint: this will involve another set of convolutional/pooling layers and then the linear layers'''\n",
        "        x = self.conv2(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.pool2(x)\n",
        "\n",
        "        x = self.flatten(x)\n",
        "        x = self.fc1(x)\n",
        "        x = self.relu(x)\n",
        "        x = self.fc2(x)\n",
        "\n",
        "        return x\n",
        "\n",
        "# Instantiate the model\n",
        "cnn_model = CNN().to(device)\n",
        "# Initialize the model by passing some data through\n",
        "image, label = train_dataset[0]\n",
        "image = image.to(device).unsqueeze(0)  # Add batch dimension → Shape: (1, 1, 28, 28)\n",
        "output = cnn_model(image)\n",
        "# Print the model summary\n",
        "print(cnn_model)"
      ]
    },
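    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The input size of the first fully connected layer, `36 * 5 * 5`, follows from how each layer changes the spatial size of the feature maps. The sketch below traces this arithmetic in plain Python. It assumes the defaults used in the `CNN` class above: convolutions with stride 1 and no padding, and max pooling whose stride equals its kernel size, with the output size rounded down. The helper names `conv_out` and `pool_out` are hypothetical."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helpers (not part of the lab code) for one spatial dimension\n",
        "def conv_out(size, kernel_size, stride=1, padding=0):\n",
        "    # output size of a convolution along one dimension\n",
        "    return (size + 2 * padding - kernel_size) // stride + 1\n",
        "\n",
        "def pool_out(size, kernel_size):\n",
        "    # output size of max pooling whose stride equals its kernel size\n",
        "    return (size - kernel_size) // kernel_size + 1\n",
        "\n",
        "side = 28                  # MNIST images are 28x28\n",
        "side = conv_out(side, 3)   # after conv1 (3x3 kernel)\n",
        "side = pool_out(side, 2)   # after pool1 (2x2 window)\n",
        "side = conv_out(side, 3)   # after conv2 (3x3 kernel)\n",
        "side = pool_out(side, 2)   # after pool2 (2x2 window)\n",
        "\n",
        "out_channels = 36          # number of output channels of conv2\n",
        "flattened_features = out_channels * side * side\n",
        "\n",
        "print('final feature map side length:', side)\n",
        "print('features fed into fc1:', flattened_features)"
      ]
    },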
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kUAXIBynCih2"
      },
      "source": [
        "### Train and test the CNN model\n",
        "\n",
        "Earlier in the lab, we defined a `train` function. The body of the function is quite useful because it allows us to have control over the training model, and to record differentiation operations during training by computing the gradients using `loss.backward()`. You may recall seeing this in Lab 1 Part 1.\n",
        "\n",
        "We'll use this same framework to train our `cnn_model` using stochastic gradient descent. You are free to implement the following parts with or without the train and evaluate functions we defined above. What is most important is understanding how to manipulate the bodies of those functions to train and test models.\n",
        "\n",
        "As we've done above, we can define the loss function, optimizer, and calculate the accuracy of the model. Define an optimizer and learning rate of choice. Feel free to modify as you see fit to optimize your model's performance."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vheyanDkCg6a"
      },
      "outputs": [],
      "source": [
        "# Rebuild the CNN model\n",
        "cnn_model = CNN().to(device)\n",
        "\n",
        "# Define hyperparams\n",
        "batch_size = 64\n",
        "epochs = 7\n",
        "optimizer = optim.SGD(cnn_model.parameters(), lr=1e-2)\n",
        "\n",
        "# TODO: instantiate the cross entropy loss function\n",
        "loss_function = nn.CrossEntropyLoss()\n",
        "# loss_function = # TODO\n",
        "\n",
        "# Redefine trainloader with new batch size parameter (tweak as see fit if optimizing)\n",
        "trainset_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
        "testset_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bzgOEAXVsUNZ"
      },
      "outputs": [],
      "source": [
        "loss_history = mdl.util.LossHistory(smoothing_factor=0.95) # to record the evolution of the loss\n",
        "plotter = mdl.util.PeriodicPlotter(sec=2, xlabel='Iterations', ylabel='Loss', scale='semilogy')\n",
        "\n",
        "# Initialize new comet experiment\n",
        "comet_ml.init(project_name=\"6.s191lab2_part1_CNN\")\n",
        "comet_model_2 = comet_ml.Experiment()\n",
        "\n",
        "if hasattr(tqdm, '_instances'): tqdm._instances.clear() # clear if it exists\n",
        "\n",
        "# Training loop!\n",
        "cnn_model.train()\n",
        "\n",
        "for epoch in range(epochs):\n",
        "    total_loss = 0\n",
        "    correct_pred = 0\n",
        "    total_pred = 0\n",
        "\n",
        "    # First grab a batch of training data which our data loader returns as a tensor\n",
        "    for idx, (images, labels) in enumerate(tqdm(trainset_loader)):\n",
        "        images, labels = images.to(device), labels.to(device)\n",
        "\n",
        "        # Forward pass\n",
        "        #'''TODO: feed the images into the model and obtain the predictions'''\n",
        "        logits = cnn_model(images)\n",
        "        # logits = # TODO\n",
        "\n",
        "        #'''TODO: compute the categorical cross entropy loss\n",
        "        loss = loss_function(logits, labels)\n",
        "        # loss = # TODO\n",
        "        # Get the loss and log it to comet and the loss_history record\n",
        "        loss_value = loss.item()\n",
        "        comet_model_2.log_metric(\"loss\", loss_value, step=idx)\n",
        "        loss_history.append(loss_value) # append the loss to the loss_history record\n",
        "        plotter.plot(loss_history.get())\n",
        "\n",
        "        # Backpropagation/backward pass\n",
        "        '''TODO: Compute gradients for all model parameters and propagate backwads\n",
        "            to update model parameters. remember to reset your optimizer!'''\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "\n",
        "        # Get the prediction and tally metrics\n",
        "        predicted = torch.argmax(logits, dim=1)\n",
        "        correct_pred += (predicted == labels).sum().item()\n",
        "        total_pred += labels.size(0)\n",
        "\n",
        "    # Compute metrics\n",
        "    total_epoch_loss = total_loss / total_pred\n",
        "    epoch_accuracy = correct_pred / total_pred\n",
        "    print(f\"Epoch {epoch + 1}, Loss: {total_epoch_loss}, Accuracy: {epoch_accuracy:.4f}\")\n",
        "\n",
        "comet_model_2.log_figure(figure=plt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UG3ZXwYOsUNZ"
      },
      "source": [
        "### Evaluate the CNN Model\n",
        "\n",
        "Now that we've trained the model, let's evaluate it on the test dataset."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JDm4znZcDtNl"
      },
      "outputs": [],
      "source": [
        "'''TODO: Evaluate the CNN model!'''\n",
        "\n",
        "test_loss, test_acc = evaluate(cnn_model, trainset_loader, loss_function)\n",
        "# test_loss, test_acc = # TODO\n",
        "\n",
        "print('Test accuracy:', test_acc)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2rvEgK82Glv9"
      },
      "source": [
        "What is the highest accuracy you're able to achieve using the CNN model, and how does the accuracy of the CNN model compare to the accuracy of the simple fully connected network? What optimizers and learning rates seem to be optimal for training the CNN model?\n",
        "\n",
        "Feel free to click the Comet links to investigate the training/accuracy curves for your model."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xsoS7CPDCaXH"
      },
      "source": [
        "### Make predictions with the CNN model\n",
        "\n",
        "With the model trained, we can use it to make predictions about some images."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gl91RPhdCaXI"
      },
      "outputs": [],
      "source": [
        "test_image, test_label = test_dataset[0]\n",
        "test_image = test_image.to(device).unsqueeze(0)\n",
        "\n",
        "# put the model in evaluation (inference) mode\n",
        "cnn_model.eval()\n",
        "predictions_test_image = cnn_model(test_image)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x9Kk1voUCaXJ"
      },
      "source": [
        "With this function call, the model has predicted the label of the first image in the testing set. Let's take a look at the prediction:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3DmJEUinCaXK"
      },
      "outputs": [],
      "source": [
        "predictions_test_image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-hw1hgeSCaXN"
      },
      "source": [
        "As you can see, a prediction is an array of 10 numbers. Recall that the output of our model is a  distribution over the 10 digit classes. Thus, these numbers describe the model's predicted likelihood that the image corresponds to each of the 10 different digits.\n",
        "\n",
        "Let's look at the digit that has the highest likelihood for the first image in the test dataset:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qsqenuPnCaXO"
      },
      "outputs": [],
      "source": [
        "'''TODO: identify the digit with the highest likelihood prediction for the first\n",
        "    image in the test dataset. '''\n",
        "predictions_value = predictions_test_image.cpu().detach().numpy() #.cpu() to copy tensor to memory first\n",
        "prediction = np.argmax(predictions_value)\n",
        "# prediction = # TODO\n",
        "print(prediction)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E51yS7iCCaXO"
      },
      "source": [
        "So, the model is most confident that this image is a \"???\". We can check the test label (remember, this is the true identity of the digit) to see if this prediction is correct:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sd7Pgsu6CaXP"
      },
      "outputs": [],
      "source": [
        "print(\"Label of this digit is:\", test_label)\n",
        "plt.imshow(test_image[0,0,:,:].cpu(), cmap=plt.cm.binary)\n",
        "comet_model_2.log_figure(figure=plt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ygh2yYC972ne"
      },
      "source": [
        "It is! Let's visualize the classification results on the MNIST dataset. We will plot images from the test dataset along with their predicted label, as well as a histogram that provides the prediction probabilities for each of the digits.\n",
        "\n",
        "Recall that in PyTorch the MNIST dataset is typically accessed using a DataLoader to iterate through the test set in smaller, manageable batches. By appending the predictions, test labels, and test images from each batch, we will first gradually accumulate all the data needed for visualization into singular variables to observe our model's predictions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "v6OqZSiAsUNf"
      },
      "outputs": [],
      "source": [
        "# Initialize variables to store all data\n",
        "all_predictions = []\n",
        "all_labels = []\n",
        "all_images = []\n",
        "\n",
        "# Process test set in batches\n",
        "with torch.no_grad():\n",
        "    for images, labels in testset_loader:\n",
        "        outputs = cnn_model(images)\n",
        "\n",
        "        # Apply softmax to get probabilities from the predicted logits\n",
        "        probabilities = torch.nn.functional.softmax(outputs, dim=1)\n",
        "\n",
        "        # Get predicted classes\n",
        "        predicted = torch.argmax(probabilities, dim=1)\n",
        "\n",
        "        all_predictions.append(probabilities)\n",
        "        all_labels.append(labels)\n",
        "        all_images.append(images)\n",
        "\n",
        "all_predictions = torch.cat(all_predictions)  # Shape: (total_samples, num_classes)\n",
        "all_labels = torch.cat(all_labels)            # Shape: (total_samples,)\n",
        "all_images = torch.cat(all_images)            # Shape: (total_samples, 1, 28, 28)\n",
        "\n",
        "# Convert tensors to NumPy for compatibility with plotting functions\n",
        "predictions = all_predictions.cpu().numpy()  # Shape: (total_samples, num_classes)\n",
        "test_labels = all_labels.cpu().numpy()       # Shape: (total_samples,)\n",
        "test_images = all_images.cpu().numpy()       # Shape: (total_samples, 1, 28, 28)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HV5jw-5HwSmO"
      },
      "outputs": [],
      "source": [
        "#@title Change the slider to look at the model's predictions! { run: \"auto\" }\n",
        "\n",
        "image_index = 79 #@param {type:\"slider\", min:0, max:100, step:1}\n",
        "plt.subplot(1,2,1)\n",
        "mdl.lab2.plot_image_prediction(image_index, predictions, test_labels, test_images)\n",
        "plt.subplot(1,2,2)\n",
        "mdl.lab2.plot_value_prediction(image_index, predictions, test_labels)\n",
        "comet_model_2.log_figure(figure=plt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kgdvGD52CaXR"
      },
      "source": [
        "We can also plot several images along with their predictions, where correct prediction labels are blue and incorrect prediction labels are grey. The number gives the percent confidence (out of 100) for the predicted label. Note the model can be very confident in an incorrect prediction!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hQlnbqaw2Qu_"
      },
      "outputs": [],
      "source": [
        "# Plots the first X test images, their predicted label, and the true label\n",
        "# Color correct predictions in blue, incorrect predictions in red\n",
        "num_rows = 5\n",
        "num_cols = 4\n",
        "num_images = num_rows*num_cols\n",
        "plt.figure(figsize=(2*2*num_cols, 2*num_rows))\n",
        "for i in range(num_images):\n",
        "  plt.subplot(num_rows, 2*num_cols, 2*i+1)\n",
        "  mdl.lab2.plot_image_prediction(i, predictions, test_labels, test_images)\n",
        "  plt.subplot(num_rows, 2*num_cols, 2*i+2)\n",
        "  mdl.lab2.plot_value_prediction(i, predictions, test_labels)\n",
        "comet_model_2.log_figure(figure=plt)\n",
        "comet_model_2.end()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3cNtDhVaqEdR"
      },
      "source": [
        "## 1.5 Conclusion\n",
        "In this part of the lab, you had the chance to play with different MNIST classifiers with different architectures (fully-connected layers only, CNN), and experiment with how different hyperparameters affect accuracy (learning rate, etc.). The next part of the lab explores another application of CNNs, facial detection, and some drawbacks of AI systems in real world applications, like issues of bias."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "Xmf_JRJa_N8C"
      ],
      "name": "PT_Part1_MNIST_Solution.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}